{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3333333432674408\n"
     ]
    }
   ],
   "source": [
    "# L1-norm loss (mean absolute error).\n",
    "# |1-2| + |2-2| + |3-3| = 1, averaged over the 3 elements -> 1/3.\n",
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "target = torch.tensor([2.0, 2.0, 3.0])\n",
    "criterion = nn.L1Loss(reduction=\"mean\")\n",
    "loss = criterion(x, target)\n",
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.16833333671092987\n"
     ]
    }
   ],
   "source": [
    "# SmoothL1Loss: 0.5 * d**2 where |d| < 1, otherwise |d| - 0.5.\n",
    "# Here d = [-1, 0.1, 0] -> (0.5 + 0.005 + 0) / 3.\n",
    "x = torch.tensor([1.0, 2.1, 3.0])\n",
    "target = torch.tensor([2.0, 2.0, 3.0])\n",
    "criterion = nn.SmoothL1Loss(reduction=\"mean\")\n",
    "loss = criterion(x, target)\n",
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.33666667342185974\n"
     ]
    }
   ],
   "source": [
    "# MSE loss: mean of the squared differences.\n",
    "# (1**2 + 0.1**2 + 0**2) / 3 = 1.01 / 3.\n",
    "x = torch.tensor([1.0, 2.1, 3.0])\n",
    "target = torch.tensor([2.0, 2.0, 3.0])\n",
    "criterion = nn.MSELoss(reduction=\"mean\")\n",
    "loss = criterion(x, target)\n",
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BCE: binary cross-entropy loss for binary classification.\n",
    "# nn.BCELoss expects probabilities in [0, 1]. nn.BCEWithLogitsLoss applies the\n",
    "# sigmoid itself and therefore expects raw logits; feeding it probabilities\n",
    "# squashes them a second time and yields a different (wrong) value.\n",
    "target = torch.Tensor([0,0,1,1,0,1])\n",
    "predict = torch.Tensor([0.3,0.2,0.8,0.9,0.1,0.95])\n",
    "criterion = nn.BCELoss()\n",
    "loss = criterion(predict,target)\n",
    "print(loss.item())\n",
    "# Recover the logits with the inverse sigmoid so both criteria see the same model output.\n",
    "logits = torch.log(predict / (1 - predict))\n",
    "bce_with_logits_loss = nn.BCEWithLogitsLoss()\n",
    "loss_from_logits = bce_with_logits_loss(logits,target)\n",
    "print(loss_from_logits.item())\n",
    "# The two formulations must agree up to floating-point error.\n",
    "assert torch.allclose(loss, loss_from_logits, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.9459, 0.6897])\n",
      "tensor(0.8178)\n",
      "tensor(1.6356)\n"
     ]
    }
   ],
   "source": [
    "# CrossEntropyLoss on raw class scores, shown once per reduction mode:\n",
    "# per-sample values, their mean, and their sum.\n",
    "predict = torch.tensor([[0.1, 0.5, 0.4], [0.1, 0.1, 0.8]])\n",
    "label = torch.tensor([1, 2], dtype=torch.long)\n",
    "for reduction in (\"none\", \"mean\", \"sum\"):\n",
    "    loss = nn.CrossEntropyLoss(reduction=reduction)\n",
    "    print(loss(predict, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step1 softmax:tensor([[0.2603, 0.3883, 0.3514],\n",
      "        [0.2491, 0.2491, 0.5017]])\n",
      "step2 log:tensor([[-1.3459, -0.9459, -1.0459],\n",
      "        [-1.3897, -1.3897, -0.6897]])\n",
      "step3 nll_loss:0.8178186416625977\n"
     ]
    }
   ],
   "source": [
    "# CrossEntropyLoss taken apart by hand: softmax -> log -> NLL loss.\n",
    "# The final value equals the reduction=\"mean\" result of nn.CrossEntropyLoss.\n",
    "import torch.nn.functional as F\n",
    "predict = torch.tensor([[0.1, 0.5, 0.4], [0.1, 0.1, 0.8]])\n",
    "label = torch.tensor([1, 2], dtype=torch.long)\n",
    "softmax = predict.softmax(dim=1)\n",
    "print(f\"step1 softmax:{softmax}\")\n",
    "_log = softmax.log()\n",
    "print(f\"step2 log:{_log}\")\n",
    "nll_loss = F.nll_loss(_log, label)\n",
    "print(f\"step3 nll_loss:{nll_loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# KLDivLoss(input, target) evaluates target * (log(target) - input), so `input`\n",
    "# must hold the LOG-probabilities of the predicted distribution and `target` the\n",
    "# reference probabilities. Passing raw probabilities, or swapping the two\n",
    "# arguments, produces a meaningless (here even negative) number.\n",
    "predict = torch.Tensor([0.1,0.5,0.4])\n",
    "label = torch.Tensor([0.4,0.5,0.1])\n",
    "# reduction=\"sum\" adds the per-class terms, giving KL(label || predict) for this\n",
    "# single distribution; the default \"mean\" would divide by the element count.\n",
    "loss = nn.KLDivLoss(reduction=\"sum\")\n",
    "kl_div = loss(predict.log(), label)\n",
    "# A KL divergence can never be negative.\n",
    "assert kl_div.item() >= 0\n",
    "kl_div.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3928571045398712"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CosineEmbeddingLoss: label 1 -> 1 - cos(x, y); label -1 -> max(0, cos(x, y) - margin).\n",
    "# Row 0 is a pair labelled dissimilar (-1); row 1 is an identical pair labelled similar (1).\n",
    "x = torch.tensor([[0.1, 0.5, 0.4], [0.1, 0.5, 0.4]])\n",
    "y = torch.tensor([[0.4, 0.5, 0.1], [0.1, 0.5, 0.4]])\n",
    "label = torch.tensor([-1.0, 1.0])\n",
    "loss = nn.CosineEmbeddingLoss(margin=0.0, reduction=\"mean\")\n",
    "loss(x, y, label).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAksAAABgCAYAAAANSTbeAAAY/0lEQVR4nO2dbWhc153Gn24XfTBoKQgM1mKQ7GXdXZhdgTMYNG5AqsEqxh0CURPKaMFSQ9buRqNA4gYy8gdL2bZKupa8ZEVwJH/QUJI6u+lgnGpZI0HWIzATuk7nQ+qCLYFBLsoKUgz+IDb23nPuOTN37ttcSfNy5+r5mcEzc+/cOedq7rnP+b+dbzx58uQpCCGEEEKIK3/W7AYQQgghhIQZiiVCCCGEEB8olgghhBBCfKBYIoQQQgjxgWKJEEIIIcQHiiVCCCGEEB8olgghhBBCfKBYIoQQQgjxgWKJEEIIIcQHiiVCCCGEEB8olgghhBBCfKBYIoQQQgjxgWKJEEIIIcQHiiVCCCGEEB8olgghTWQTuXQc8bh6pHPGO4QQEi4olgghTUIIpQFMHp5HoVAwHovIYBIDFEyEkJBBsUQIqTub18ecVqM7WUyupDA/GlNvdCD5ZgaJlWXc2mhCIwkhxAOKJUKIKWakK2wGReuGOzPq/THkdipgNnKYuAhk3kwacqhM8dMs0NuFTuu++w+hG3ks36ZtiRASHv682Q0ghDSfjmN9SBgiJY8slu6kEesx35eCRjA0guT+YMcqXo5jeMH5fv5UHJO9GSzOCNG0ifv3jDcPH6oQUDCkU1cvkF1dF63acX8IIaSW0LJECAH2JzEyZD7NfqptS0UsKdGTejbm/Iy0OtksUQax0YKKQTIeV1Li05jXr2eSlECEkJaDYokQIok9mzKfLCyZAujOEqRdqTeDVI9zf2l1GuqHi4xSbCJ3NYvUlbTPPu4kujur70QIIQ2CbjhCGomwxry0isyNEaydGoZyciFxYRHTp5tsc+lJIdMrgq6FKy6F+1fN1iVOHHe1BkkLksv7DjfcSrzUT5TccB04dNh4fe++IaliluOvY23F+O9ETXpECCE1gZYl0lK4ZlW1EGYMUB6Tp5bQr9Ple413LmYd7qzG04HjJxLyWfbqBJaFaEEKI9sUcaYbbh7SAXfF4pKzueE6u43vWlnDuvXDG/exigT6jtFZRwgJDxRLpEUwixcOXMw3uyG7QMUADYm6Qto1pSws9qywJtFxekSKHKyIYG94utmE9Wjsuo9kFS48D/dd5XdlMXxZy0Tjb/zWJPLbCCYnhJBGQDccCT3CmtTaIkkhY4ASyLxolR9mVpiXq6vxxNA/BGSlG83eVo0QfQn03fBqsRmrlDixWKVPMaQL80B8GHHtthNCcnS7EU6EEFJfKJZIiFEVnlfMV6kLGaxenESryiazrlAGx61Wk41bWF4xhMeb4ZBKFfT2VbZVI61GfVj0tP50IDlTQDLQlwjBVEB6p20khJAGQDccCT9D5nIY6WPNbshuMF1wdgvS5u1l5L1ESTPYyGFOlws4457mL0RfeCxhhBBSf2hZIiFmOxaKkKNdcBWBy5u4dTNvCI/x5gsPmaWXLb/2jDeq5oIjhJDoQcsSIQ3AdMHZLEhybbQQZn6V0vtdUC640FjCCCGkAdCyREgDcK1J1JNGwa1QUTOQbakeOWS64KoFbhNCSLSgWCKEBOPODIbvZbA4SqlECNlb0A1HCPFF1FSKx43H1S5v9xwhhEQYWpYIIb5IF+Jos1tBCCHNg5YlQgghhBAfKJYIIYQQQnygWCKEEEII8aGlxZIOPJ250+yWEEIIISSqMMC77hQxEx9GFinMl1aaby5CZA4vAKkrBaR9VoUnfui/a1DC8/cPRoj6Z68u3nLnshK9MHTiwiKmT1tyCzdyGDtVufZhTa9Refxl9N2YRjIURUX12o8JZELTphZD/2bcCsn6bbOtuyng/cCflrYstQLFy+ESSgKR3TQvVpZ/aQy5jWa3ZhvsT2K6UEAhFOnrnejqbXYb6knU+9ckjBvYhJtQEoLwlHOR6OxLccQvF3f9tUKguR2/mWxen6BQ2i1iTLyRQWJlEgP2
34nfNrJtaFmqJ6KIn1g89UIqNEJJE3vRuIgWJjH5QRHJ0bC1rhUQ69bNY01ZX6I3Kwtf/xwCo+UwZvNvCcFiTJ4q+lHEjLKcVfRRW9QW5pB7cYeCwsVaFQqUaBSLZFMo7RJDFI0MGX9jt9+J57byupva0kn8oWWpbhgD41V5m8FIGAd4eRHBHIhbyboUKmJIX0nJZy1npQtE1PvXYORagM7J0+b1OdPdaQiHCjHYk5YWYBhSR0xqtkuFNWkog0yILIXFD0S7Esi8yIlaLZCTX4/fid82EhyKpXqxcQvL9oHx0SoKtwulR/HBVuVnttZRtGy/u7FlP6oT22eKD10+85Xlez9bxSP1Ni+iGmDc0BYvJCBvaKdmELkzGfX+NRC5mLJj8rSJWzeFnHEXDuY1anDvvrHnThAhAKKo6PEdfbo+FLG0YPw3NGKxdGxh84uCZXy8i82vKz/16F7BdRxrKbY2cdcyXhf+YP+rPsKqdftawF76TX45Ma4JEXbDOQPYBH6mfKc50j3WKMh+m7eX5YwuddDyXe1/gc0PUxj/byVous4i+8EwjnzTfLn6cQbD7xQt2+LVu9kG/P69c5jSd7GTU1ia7EO7ZZe7vx7DuXfX5fPOl7PIPaM27D+OPmO2mZcDcSwEcUCtScfpcWRuit9aFsOX+40bU7Rmy2HuX2Wygi0ofWje0lZ7wLp3HKGrW6LiWLAEnbvE3JRcX9ZtSiD0dqGz4sDrWJNjVDcOubmj9h8ythjX6MoaxBW8nWu04/Q0Cqe38YFGcWdJ/h0S3dYz0Yb2//sMb//TPFbV64GfLmLihBrJHucxOzyGa4/L2wKMjuGjrR1bd97GuffNXmLfAKZyE+j7lvlya2UWqfQ1bJW2Be9l7NmUIYiyWBM/lP3Bt5FgRNOyJAayuFMoCfIXB4xt9hmyEFZxF7+tcXNI5ywzuuD7mbPFFPor4jw6MPD6eXOmKFibQ3ZZzRy+Wsb8v+lWxXD+7bKI8qcTAy/0lV/+ZwG/qzAu3UX+P9bV8yNInThS0Z5Dh43/VpZxK8CMQ5r14/EAj71mgTDje6TDamEYY9d3ZgMILy3QvweGQLFn7xltlcHRGy7bxDXr8jsV4ss1fkMcy3qNWy1uNsus6WISAs4iorRAOHHclrF03xQHDhGl0YH2q7gfEauAaWFLoO9YpfRri43g9R/qs7CFxfeu4a6yLt391YwSSsZ+z09h/EQ7WpM2xIZeR+qAevl4EbO/vms+//ouspeUUDL2G/zpeElEORHiP155LXZ2yXtL9lOX0ddvGwlE9MSSGBh1mrGYDQoTtH7cUCZtMVBaswNULEHJZF16qBvEdvfTs0W3AfBAEuff1HLJGBAumwPC3Y+msagGg9hr4xjsCt7l9r5BDO7Tr3LI/49FLa39DksP1fNYEgnbcTu7zQF/bR1kV8SQVr+v/MWJCJq7w92/7MVJ4MJi+ZpUsVZS5Agrj3UssIwDSy412oSVynXcWJlE1rJ/x+kRJSAt7g2V1CHGHmtA/OYD05LQfXCv2283cf+e+N/NktaG+MuZspBYm8XcTWMy+dUi5t7VlphBTL2SEAb11mVfHGcvpEr3htV357D4FfDo5hxm18z3hCAc6/XppRTfNsGprJCuLlu/bSQQkRNLelbnMJsLdCqleL6w5LR+OMSNcYNwS1MPut/hQ65m887TaZzXTXtoDAi/zGLuPaVWejOYfL7bq3vutMWRHCrPyHK375Y2rd7OQb/qe2HAId46Dprftfqg+iUkzfqFQoBHeMokNBTj9zUe5fieMPfPMzga8ppatI4FpX44Z9qirIYj60/HfMB+negAeON8vCWsTjqrzZhMuboqE+hyNx/tPbwsaUJIvDFYEkPL72aR/XAWy/JVJ1K/GENin9sHW4u2Z87i/POlXmL2Q6Of75q9xIEUpqsJQuP3XSjYMySVFVK5bCvx20aCEDGxpGctPlkWKk6nYlbZkzIzRUQ9CuFGqnCpWQi6nzat
e/HNbgy+cRZaEi1fnlGDgdHuN5LoDOR+q+TIyXIg+da/G0JQmq9X8dmNklRC/7FWNV23DkJQmjdpu2s2GoS1f6lnnde7aTV1cX2hPEnwxnS5a9eytBa5oUWZMSZMpOdUmQXnZGF9tXVTs/VKCVUfgX4POkbLm7bes5g8qaTCw3nMvK/iLX+YwdlnmmNTqu05ELQh8bJxH1HCb/39GcxLD4AhCC+cRTwCgjBqREwsVQmWlKg4Hdt7oubEopptlsRQ3J4uHXS/ABzuw/f/xvbegb/Dt3cafHdwAD84oZ4/XkLhD8b/Dz7Db74w32p7YRD9nv5vUktKGUw2t01UiHL/yjdF95hHN2Kjphs+v5J3uN8C4znjDzKmRY12JE5+z/Hukb/9dmu73+x8K4GB79rfPIIjf1WllzIml6U8Gk3EsuG0qdEMhoy5Di5l65PdJF7OHtGZdKa74ZDNrVR1P53B4tPS9U9mMKuFzL42bD3eki65n3/Uh/kXtumGk7Sj//Qgxm9eE0eXrrhkR165StqQ/M7f73qgCV68bHsVy8XNqdYId2Bz0IUHzczLZhdyrD3R7Z/4fUsLksvyEH6//VKdJIFHAUlt5XJQGis8xqyqAeD1R7gmC6O1OlqAyvCPi5i9nDOfG2NjmzE2iijM5X+eRf475y1uOEvGsy3sQmdJuv0td0Jtz4HJ1p1ZzFw3n5fuAVjGxHt59P/E2w0nA+SNfh3fM+I5HETMsqStRnks3/Ywhqr6R9WsT8KCpN0NbkGggfbzCqZ7mMPUO3mV9dCH8V+MlYRF8Z0JXFvz+j5/2p7pKwV6ry/lMLesBvcDI0h6mK8ZeFpbipfLg3drV5t2J8r9066y1BnnOlpmdqsLxixfiihxU76h6pZ5xnO5JVLE0K8KT7qNWboEiZsrsaXxtKRtofj+BLJr5qvOoWlM6diex9dw/l/1uCnowPETSoRWxKCqMg1w+1uGBCEI38qqcI1OjExNlcburY/OY3rFq8ae2Tf334NPYpHvNhKEiFmWyst4yBIBq7Ygb0vp/4pikcKs+am9foy+4CwWqKD72YLpKn7UX68j9zOjDSrzrfvHIxh4phvtz09j7KMtebypSzkk/mUHsUsq0PuaCBb/4hqu6dZ8P+FZhsC8QQQLPK1X3ZbmWYFqS8ky4Rng29pEvX8aEfSd7rFaKbxccjqgO4HMm8ZN2Zh8mUtLOOtRmTFSeTNAvKfyNmcds2YOWgLMtRAL6yoAO0KXK3G3pG19NovMQrkG0djzcSS+MiaTH01JMbRl/D/73WtIq8mfzEg0zlFWTlbTiIlzp8o0OEu3mEgr4c2+mlicdsYWCu9lSoKw7eQYBo8l8L/nYrgm6+xt4drPZtH3QdoZu6Sy4DLHXFqurZBuiUV+20ggImZZQkWmi0wbtgbfnSpnyjlmxfZ9dV2WiiqzQffTMx5nbZT161PlgXffINI/EHWP2pA4Y6m/tDKFqU/K8y4dRzETID7EGuhtEkPq5BGPvf3SeMNDOY6kup9+O+eqpui1rmQhwghmA0a9f1CF+wS2a3z4XgbzF5xuNHORbDHxGi9d+zp+SRyj4jfY02/GNa262FOs2XkvWcYWVQLFETAu60ZtJ5g4XJjn2cXK9riA2YvZksUpdm7YrDPUlcSrpWzfdWR/Povi49LRlGWunNlo1nHyWpPTtBI201InBOHkL0u9xNiPzCLC3c+9Wi6b8DCLt98vwm5fMl1wfe4uuPU10xDQ7TLz9dtGAhE9sQSd4m6vfSQQA33BWVKgVGCuEllvxbpv0P1EG471SZN8hWlduN/eKpvzE6+NlP3vxoB59sc6VmkL+XemsCiFgRY07rMkB9ZAb0Hse0gc9NhXuySH+kN88yub1H3dqxX7BjxXNaOImVMuhQgjQ9T7pxDp2Fdso4aKebHfYqwxMeMVEy+v9fTUTd2tZAnUmGX/blXPzREXpm98reqaU8KxsmyDsLZMIqtrwh1I4dXn9HjYhtg/jJUy
x7CWxcRCWUiUEg7EuTWE5Jyy9NuLXpqYoRNNcyHbBGHn0KtIdqkXbTEMjw6UYpVWFyYwV7TKJT8XnHexz2rbSECePHnylI96PL58+vErR58ePXrp6ee7Oc4fP346etQ4zvTnAT/zp6ef/ER8r/l47Td/8tz3y9yosc/o04//2Oxz5fP47SWzL6+MmufhlY+ffllt38DnqpZ/56NPR3NfNv98RbF/6u/a8ue3Rv34fPpo+K/bhvbB8ht9ZbQJY0CDHvL343XOPn96ybPfftv0feDo00u/DUEfQ/yIpGUpHBgzmDNy/oS53SwPIWeR21ide20Rv7qpnu8bxGCfV22lIrLCrWJ3M4YMbVJPnRk362N5Ls2yidxVFT/SwJXMS/EsDgtDNIh6/xqKqtOWv5jdRUFPZWkO+XVbjdou4q3HWlW+Ae51t3TtrIa76GuEnwvOzMh0H/v8tpHgUCzVkxoMjr4+aoVYrXt14xE2DaE09fpU6bu6zyQR98g/NeMtwn4BWd1qVTIdtUuxyrmqKXppC+EuaVqwaB0JWf/MdR1bee3BGkyg1O/cXQy0ELoqunWpmN2gXHsSQ9in3Nzw8tw12kVfO2T5ArfrUMcTugloz23loqvBysGQyGXDhQtzAdK1+DCG49hW7SGNvECq7LP+X+ccFYbb/jqNiSH3wO7ySu0hjz/RWS06pkqtnJ2/eQubp211cBqdYm1dg1Auypr13V0g6hK1TLp91PvXLGTc46pxgxrAGHZwvsSSTYVkfdrWYGKji8jcG8DkqTHgRu3GIq8xQI4RhmiYrs3XhAOd4W1f0qfaNrJtKJbqTgxpEaTZwG9sP3Eec+OD6PYoF1CPAmv1oOSC07NoOXvMIitdcUnL4Krr4DQugFGLs6gSqv7JdbAaeQXVl3qV4Gg9zGDrWki/cmFQrzIL5hiROhMpqeQvnn2Fde3O/V6BYikCxP5xCUvPreP3m204dPgv0dEehUUB3DLbdP0qlXasxVITXHBRv+FFvX8kGpSyEhXu5QJQcsGNzDSqZSRqMGYpCrS1o/3gEcR7uiMilOB0wUnKFXutaceRrXJMCAmMnxvYdMGFuUQKCTsUSySUOFxwCrN+FSz1ahrvgiOEhAMZUlAwH97xX8oF1+pB8aSpUCyREOJTXFItPFpai68ZWXCEkJZh8/oEJg/PR2rRZ9J4KJZI+HB1wWkqlzegC44Q4sSSGr864ly1gZBtQrFEQoeXC05TXsNrDhN0wTUWsZh0gDX6CGkuZraXdNFRKJEaQLFEwoVe28mrsJygs0stOpxHni64hhKkSKrHJzFDkUUIaVEolkioCORW23/cXPpEsXsXnLiRxzG2m2VptsHm9bGWrULtWUXYB5HeHY8Po3pZS0IICScUSyREBM1sK5cQqIkLTsZI7ew4QghsT2SpPu6JNGZThC49awisK6nquxNCSEj5hlhNt9mNIKTlELE7ajmQvbLMhywAeC+DxZ2sEyfP1yoyNVzWghBCGgUreBOyTaRowDwKhX7MxIexWuvjhjIg1SznkLhwHOtGOwcWPHbr3aGYIoSQEEOxRPY2O7B4lBc33m7UkUhnHpA1XxyCyGjH8EIK84Xy+yK2yVwRPFW5CHPJqpVonKVGuSozxzoQO90aawsSQkitoFgiexozuyuD4+uGADnlFYJcI1GiCmimznTKGjCTK85dhuNGG4ZMMWVWK88b/0QBTkMs9VjaLBgaaZhLq3Se6EIjhOxBKJbIHqbsWuroSdZ9ZXsz0y+FkZ4OxCwrfksL0s0+p/tqfxIjQ5PIL5gFONM9MZSrm3vVoSpK1yCuFGpYsdhynuBcvLQCuuEIIRGEYonsXSyupfqjs+BGbFlwRWQvApkb7gJDFuBcyJpr4Y3GENPVzb3qUMntKczXcmkH23mSbki64QghewiKJbJnqXAtWbLbnNTADVdywQmppGKXrG64U3FM6udDlpimnhQyvVljX+GKS+H+VZWB51VbqieNQmEX7XSB
LjhCyF6HYonsUSpdS6bIqJ8bruyCE6/MpRiSGzmMnVpGn68QUzWlVvLIXp1AQgos4zgNK1VgnqfUFbrWCCF7FxalJHuTXRSirH7sGVuFbvdClFJABQjS7jg9AlnScSUvq5t7FrQU35vOoaZ1yJVbr3+3bj0pRlljiRDSmlAskT3Jztc4q87mg9VKQaNdcBUB2SJWKe+5WHAlMfQP6ecJZF50/4zo0+6XfrGyidxV45gXUnug2jghhHjDCt6E1BQzHmntTC2z0SwZaJ7ZZiILbg5dNay7tKuK3YQQEiFoWSKkpqxjbSWBrs4aHnIjhzldLuCMh3AR7rJaWMpEHFVcLHyrqolTKBFCCC1LhIQWe4aeTw0jYQWa694ba9QRQkijoWWJkFbAt9ijyFirU7A6IYQQWpYIaXmEBepqF2OLCCGkTrDOEiEtTREzciHgNIUSIYTUCbrhCGlFZC0nEYhd2ww4QgghTuiGI4QQQgjxgZYlQgghhBAfKJYIIYQQQnygWCKEEEII8YFiiRBCCCHEB4olQgghhBAfKJYIIYQQQnygWCKEEEII8YFiiRBCCCHEB4olQgghhBAfKJYIIYQQQnz4f4R4CS6Kq2uJAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 223,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MultiLabelMarginLoss: display the image stored at ./imgs/1.png\n",
    "# (presumably the loss formula -- TODO confirm against the file).\n",
    "from IPython.display import Image\n",
    "Image(\"./imgs/1.png\")\n",
    "# Input: x --> (N, C), y --> (N, C), where y is a LongTensor whose elements are class indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.257142543792725"
      ]
     },
     "execution_count": 257,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One sample with 7 class scores. y lists the target class indices and is only\n",
    "# read up to the first -1, so the targets here are 5, 4, 3 and 0.\n",
    "x = torch.tensor([[0.1, 0.2, 0.4, 0.8, 1.1, 4.0, 7.0]])\n",
    "y = torch.tensor([[5, 4, 3, 0, -1, 1, 2]], dtype=torch.long)\n",
    "loss = nn.MultiLabelMarginLoss()\n",
    "loss(x, y).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual check of the MultiLabelMarginLoss value computed above.\n",
    "# 5, 4, 3, 0 are the target (correct) class indices; 1, 2, 6 are the non-target ones.\n",
    "scores = [0.1, 0.2, 0.4, 0.8, 1.1, 4, 7]\n",
    "target_idx = [5, 4, 3, 0]\n",
    "other_idx = [1, 2, 6]\n",
    "# One hinge term per (non-target i, target j) pair: max(0, 1 - (x[j] - x[i])).\n",
    "# The clamp at 0 matters: without it, well-separated pairs contribute negative values.\n",
    "hinge_terms = [[max(0, 1 - (scores[j] - scores[i])) for j in target_idx] for i in other_idx]\n",
    "# Add up every term and divide by the number of classes.\n",
    "sum(sum(row) for row in hinge_terms) / len(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.257142857142857"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Clamped hinge terms written out by hand (one group of four per non-target class),\n",
    "# divided by the 7 classes; agrees with loss(x, y).item() up to float32 precision.\n",
    "hinge_total = 0+0.1+0.4+1.1 + 0+0.3+0.6+1.3 + 4+6.9+7.2+7.9\n",
    "hinge_total / 7"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
